package com.bayesianWeka;

import weka.classifiers.AbstractClassifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;

import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.HashSet;

/**
 * @Author: Decade
 * @Date: 2021/2/5 14:42
 */
public class TAN extends AbstractClassifier {

    /*类标签的个数*/
    private int m_NumClasses;

    /*属性的个数，在weka中，本值为预测属性和类属性之和*/
    private int m_NumAttributes;

    /*训练集类标签的索引*/
    private int m_ClassIndex;

    /** TAN结构的父结点 */
    private int[] m_parents;

    /*通用工具类*/
    private CommonUtils commonUtils;

    /**
     * 训练阶段，本阶段主要做两件事情：1. 变量的初始化 数组空间的申请 2. General版本模型分类器的构建
     * @param instances 测试集对象
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {

        commonUtils = new CommonUtils(instances);

        m_ClassIndex = commonUtils.getM_ClassIndex();          //类标签的索引
        m_NumAttributes = commonUtils.getM_NumAttributes();    //属性个数 算上类标签的个数
        m_NumClasses = commonUtils.getM_NumClasses();          //类标签的个数
        m_parents = new int[m_NumAttributes];           //父结点数组申请空间
        //初始化为-1
        Arrays.fill(m_parents, -1);
//        instances.classIndex();
        double[][] m_CmiMetrix = commonUtils.getConditionMutualInfor(); //条件互信息矩阵
//        System.out.println("m_ClassIndex的值为："+m_ClassIndex);
/**       System.out.println("实例个数为："+m_NumInstances);
 */

//        DecimalFormat df = new DecimalFormat("0.00000");
//        System.out.println("----------------------------");
//        for (int att1 = 0; att1 < m_NumAttributes; att1++) {
//            for (int att2 = 0; att2 < m_NumAttributes; att2++) {
//                if (att1 != m_ClassIndex && att2 != m_ClassIndex)
//                    System.out.print(df.format(m_CmiMetrix[att1][att2]) + "   ");
//            }
//            System.out.println();
//        }
//        System.out.println("----------------------------");

        //申请空间和初始化的工作已经完毕 下面就是TAN的逻辑了 我用Prim算法写的
        double maxCmi = -Double.MAX_VALUE;
        int FirstAtt = 0;       //第一个入图的结点
        int topCadidate = 0;    //最大的候选结点
        HashSet<Integer> available = new HashSet<Integer>();
        HashSet<Integer> inMap = new HashSet<Integer>();          //来两个集合，用来存放入图的结点和未入图的结点
        int[] assumingparents = new int[m_NumAttributes];

        if(m_ClassIndex == 0) {  //一般来说进不了这个判断条件,如果第0个位置正好就是类标签 那我们从第一个位置开始玩
            inMap.add(FirstAtt + 1);
            for (int i = 2; i < m_NumAttributes; i++) {
                available.add(i);
            }
        } else{         //正常情况下类标签都是最后一个  就这么写
            for (int i = 1; i < m_NumAttributes; i++) {
                if(i == m_ClassIndex){
                    continue;
                }
                available.add(i);
            }
            inMap.add(FirstAtt);
        }

        while(!available.isEmpty()){
//            Iterator<Integer> itavailable = available.iterator();
//            Iterator<Integer> itinMap = inMap.iterator();
//            while (itavailable.hasNext()){
//                while (itinMap.hasNext()){
            for(int iterator_available:available){
//                    int iterator_available = itavailable.next();
                for(int iterator_inmap:inMap){
//                    int iterator_inmap = itinMap.next();

                    if(m_CmiMetrix[iterator_available][iterator_inmap]>maxCmi){
                        maxCmi = m_CmiMetrix[iterator_available][iterator_inmap];
                        topCadidate = iterator_available;
                        assumingparents[topCadidate] = iterator_inmap;
                    }
                }
            }
//            System.out.println("结束");
            m_parents[topCadidate] = assumingparents[topCadidate];
            available.remove(topCadidate);
            inMap.add(topCadidate);
            maxCmi = -Double.MAX_VALUE;
        }
//        commonUtils.Count3XAnd4X();
//        double[][][][] xxxxEntropyInf = commonUtils.getXXXXEntropyInf();
//        for (int i = 0; i < xxxxEntropyInf.length; i++) {
//            for (int j = 0; j < xxxxEntropyInf[i].length; j++) {
//                for (int k = 0; k < xxxxEntropyInf[i][j].length; k++) {
//                    for (int l = 0; l < xxxxEntropyInf[i][j][k].length; l++) {
//                        System.out.print(xxxxEntropyInf[i][j][j][k] + "     ");
//                    }
//                }
//            }
//        }
//        System.out.println("****************");
    }

    /**
     * 分类/决策阶段，本阶段主要做两件事：
     * 1. 给定测试样本，计算联合概率，并返回归一化后的联合概率(即后验概率)，用于分类器做决策
     * 2. Local版本分类器的构建
     * @param instance 一个测试样本实例
     * @return 后验概率
     */
    public double[] distributionForInstance(Instance instance) throws Exception{

        double[] probs = new double[m_NumClasses];
        for (int i = 0; i < m_NumClasses; i++) {
            probs[i] = commonUtils.P(i);        //P(y)
        }

        for (int x1 = 0; x1 < m_NumAttributes; x1++) {
            if(x1 == m_ClassIndex)
                continue;
            int parent = m_parents[x1];
            if(parent == -1){
                for (int y = 0; y < m_NumClasses; y++) {    //p(x0|y) = (F(x0,y)+1/F(x0)) / (1+F(y))
                    probs[y] *= commonUtils.P(instance,x1,y);
                }
            } else{
                for (int y = 0; y < m_NumClasses; y++) {    //P(X0=x0|X1=x1,Y=y) = (F(x0,x1,y) + 1/F(x0)) / (1+F(x1,y))
                    probs[y] *= commonUtils.P(instance,x1,parent,y);
                }
            }
        }

        Utils.normalize(probs);
//        System.out.println(Arrays.toString(probs));
        return probs;
    }

    /**
     * 主方法 用于运行分类器 runClassifier(new KDB(), args);语句中的KDB()要与分类器名字(即本java文件的主类名字)一一对应
     * @param argv 命令行参数
     */
    public static void main(String[] argv) {
        runClassifier(new TAN(), argv);
    }
}
